import torch.nn as nn
import torch

# Feature count in and target count out for this single-variable regression.
# They size the nn.Linear layer, so they must match the shapes stored in the
# checkpoint loaded below (load_state_dict rejects mismatched tensor sizes).
input_dim, output_dim = 1, 1


class LinearRegressionModel(nn.Module):
    """Linear regression as a single affine layer, ``y = x @ W.T + b``.

    Args:
        input_dim: Number of input features per sample.
        output_dim: Number of predicted values per sample.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # The attribute name ``linear`` determines the state_dict keys
        # ("linear.weight", "linear.bias"); keep it so saved checkpoints load.
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Apply the affine layer to ``x`` (last dim must equal ``input_dim``)."""
        return self.linear(x)
model = LinearRegressionModel(input_dim, output_dim)

# load_state_dict only accepts a mapping of tensors, so the restricted
# unpickler (weights_only=True) is sufficient here and, unlike
# weights_only=False, cannot execute arbitrary code embedded in the file.
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# machine; load_state_dict copies the values into this model's parameters.
model.load_state_dict(
    torch.load("demo2model.pkl", map_location="cpu", weights_only=True)
)
model.eval()

# Test: predict the output for x = 5.0. no_grad skips autograd bookkeeping,
# which inference does not need.
with torch.no_grad():
    print(model(torch.tensor([[5.0]])))


